// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package integrationtests

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pingcap/tidb/pkg/disttask/framework/handle"
	mockexecute "github.com/pingcap/tidb/pkg/disttask/framework/mock/execute"
	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute"
	"github.com/pingcap/tidb/pkg/disttask/framework/testutil"
	"github.com/pingcap/tidb/pkg/sessionctx"
	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
	"github.com/pingcap/tidb/pkg/util/sqlexec"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

// collectedRuntimeInfo gathers what the mocked step executors observed while
// a task ran, so that tests can assert on it afterwards.
type collectedRuntimeInfo struct {
	// subtaskInfos gets one entry per subtask that completed successfully.
	// The slice itself is not synchronized; appenders serialize through an
	// external lock.
	subtaskInfos []subtaskRuntimeInfo

	// currTaskMeta is the most recent task meta seen by a step executor,
	// either when it was created or through a meta-modification callback.
	currTaskMeta atomic.Pointer[[]byte]
	// currTaskConcurrency is the most recent task concurrency seen by a step
	// executor, either when it was created or through a resource-modification
	// callback.
	currTaskConcurrency atomic.Int64
	// activeSubtaskCount is the number of RunSubtask calls currently in flight.
	activeSubtaskCount atomic.Int64
}

// subtaskRuntimeInfo is a snapshot the mocked executor takes of each subtask
// it completes: the step the subtask belongs to and the concurrency it
// carried when it ran.
type subtaskRuntimeInfo struct {
	Step        proto.Step // step the subtask belongs to
	Concurrency int        // subtask.Concurrency at the time it ran
}

// prepareModifyTaskTest creates a DXF test context with nodeCount nodes and
// registers a mocked example task type for the modify-task tests.
//
// The mocked scheduler generates two steps, StepOne with 2 subtasks and
// StepTwo with 3 subtasks, and treats every error as retryable. Every mocked
// step executor reports what it observes into the returned runtime info.
//
// It returns:
//   - the test DXF context;
//   - the runtime info shared by all step executors;
//   - subtaskCh: each send lets one running subtask finish successfully;
//   - modifyWaitCh: while the returned flag is true, each subtask receives
//     from this channel twice before waiting on subtaskCh. This gives the test
//     a window to modify the task while that subtask is running;
//   - the flag that enables the modifyWaitCh handshake.
func prepareModifyTaskTest(t *testing.T, nodeCount int) (*testutil.TestDXFContext, *collectedRuntimeInfo, chan struct{}, chan struct{}, *atomic.Bool) {
	c := testutil.NewTestDXFContext(t, nodeCount, 16, true)
	stepInfos := []testutil.StepInfo{
		{Step: proto.StepOne, SubtaskCnt: 2},
		{Step: proto.StepTwo, SubtaskCnt: 3},
	}
	schedulerExt := testutil.GetMockSchedulerExt(c.MockCtrl, testutil.SchedulerInfo{
		AllErrorRetryable: true,
		StepInfos:         stepInfos,
	})
	subtaskCh := make(chan struct{})
	runtimeInfo := &collectedRuntimeInfo{}
	var (
		testModifyWhenSubtaskRun atomic.Bool
		modifyWaitCh             = make(chan struct{})
	)

	// mu serializes appends to runtimeInfo.subtaskInfos. RunSubtask can run
	// concurrently, e.g. when subtasks execute on several nodes at once.
	var mu sync.Mutex
	runSubtaskFn := func(ctx context.Context, subtask *proto.Subtask) error {
		runtimeInfo.activeSubtaskCount.Add(1)
		defer runtimeInfo.activeSubtaskCount.Add(-1)
		if testModifyWhenSubtaskRun.Load() {
			// Two-phase handshake. The test's first send completes only once a
			// subtask is running. The second send releases the subtask after
			// the test has finished modifying the task.
			<-modifyWaitCh
			<-modifyWaitCh
		}
		select {
		case <-subtaskCh:
			mu.Lock()
			runtimeInfo.subtaskInfos = append(runtimeInfo.subtaskInfos, subtaskRuntimeInfo{
				Step:        subtask.Step,
				Concurrency: subtask.Concurrency,
			})
			mu.Unlock()
		case <-ctx.Done():
			return ctx.Err()
		}
		return nil
	}

	executorExt := testutil.GetCommonTaskExecutorExt(c.MockCtrl, func(task *proto.Task) (execute.StepExecutor, error) {
		// Store a pointer to a local copy of the slice header, not &task.Meta.
		// The recorded meta then does not alias a field of the task object the
		// framework handed us, matching how TaskMetaModified records newMeta.
		meta := task.Meta
		runtimeInfo.currTaskMeta.Store(&meta)
		runtimeInfo.currTaskConcurrency.Store(int64(task.Concurrency))
		executor := mockexecute.NewMockStepExecutor(c.MockCtrl)
		executor.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes()
		executor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn(runSubtaskFn).AnyTimes()
		executor.EXPECT().GetStep().Return(task.Step).AnyTimes()
		executor.EXPECT().SetResource(gomock.Any()).AnyTimes()
		executor.EXPECT().Cleanup(gomock.Any()).Return(nil).AnyTimes()
		executor.EXPECT().RealtimeSummary().Return(nil).AnyTimes()
		// Record in-place modifications so tests can observe that the running
		// executor was notified.
		executor.EXPECT().TaskMetaModified(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, newMeta []byte) error {
			runtimeInfo.currTaskMeta.Store(&newMeta)
			return nil
		}).AnyTimes()
		executor.EXPECT().ResourceModified(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, newResource *proto.StepResource) error {
			runtimeInfo.currTaskConcurrency.Store(newResource.CPU.Capacity())
			return nil
		}).AnyTimes()
		return executor, nil
	})
	testutil.RegisterExampleTask(t, schedulerExt, executorExt, testutil.GetCommonCleanUpRoutine(c.MockCtrl))
	return c, runtimeInfo, subtaskCh, modifyWaitCh, &testModifyWhenSubtaskRun
}

// TestModifyTaskConcurrencyAndMeta exercises ModifyTaskByID on a single-node
// cluster. It covers tasks in different states (pending, running, paused) and
// different modifications (concurrency, max write speed carried in the task
// meta, max node count).
//
// All subtests share one test context and one collectedRuntimeInfo. They
// therefore run sequentially, and each resets the collected info when it
// finishes.
func TestModifyTaskConcurrencyAndMeta(t *testing.T) {
	c, runtimeInfo, subtaskCh, modifyWaitCh, testModifyWhenSubtaskRun := prepareModifyTaskTest(t, 1)
	// resetRuntimeInfoFn clears everything the step executors recorded, so the
	// next subtest sees only its own task. Every subtest waits for its task to
	// finish before returning; presumably no executor is still writing here
	// when this runs.
	resetRuntimeInfoFn := func() {
		*runtimeInfo = collectedRuntimeInfo{}
	}
	scope := handle.GetTargetScope()
	// Modify concurrency while the task is pending. All subtasks of both steps
	// are expected to run with the new concurrency.
	t.Run("modify pending task concurrency", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		// Submit and modify the task from the scheduler's
		// beforeGetSchedulableTasks hook, then keep the scheduler blocked until
		// the test releases it.
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k1", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				theTask = task
				// ModifyTaskByID only moves the task to 'modifying'; the new
				// concurrency has not been applied yet.
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
				<-modifySyncCh
			})
		})
		// Release the scheduler. The unbuffered handshake also makes the write
		// to theTask above visible to this goroutine.
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})

	// Modify concurrency of a running task. StepOne subtasks are expected to
	// keep concurrency 3, and all StepTwo subtasks to run with 7.
	t.Run("modify running task concurrency at step two", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			// NOTE(review): with `&&` this returns only when the task is
			// neither running nor at StepTwo. The modification therefore fires
			// on the first refresh that sees a running task, which may still be
			// StepOne. The expected results are consistent with the new
			// concurrency being applied before StepTwo's subtasks exist, so do
			// not switch to `||` without re-checking them. Confirm which
			// behavior is intended.
			if task.State != proto.TaskStateRunning && task.Step != proto.StepTwo {
				return
			}
			once.Do(func() {
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStateRunning,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				<-modifySyncCh
			})
		})
		task, err := handle.SubmitTask(c.Ctx, "k2", proto.TaskTypeExample, "", 3, scope, 0, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		// finish StepOne
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		// wait task move to 'modifying' state
		modifySyncCh <- struct{}{}
		// wait task move back to 'running' state
		require.Eventually(t, func() bool {
			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, err2)
			return gotTask.State == proto.TaskStateRunning
		}, 10*time.Second, 100*time.Millisecond)
		// finish StepTwo
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})

	// Modify concurrency from the task executor side after three subtasks have
	// been recorded (both StepOne subtasks plus the first StepTwo subtask).
	// Only the remaining StepTwo subtasks are expected to see the new value.
	t.Run("modify running task concurrency at second subtask of step two", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/beforeGetTaskByIDInRun",
			func(taskID int64) {
				// NOTE(review): this reads runtimeInfo.subtaskInfos without the
				// mutex that prepareModifyTaskTest holds while appending. If this
				// hook can run concurrently with RunSubtask, it is a data race
				// under -race. Confirm.
				if len(runtimeInfo.subtaskInfos) == 3 {
					require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, taskID, &proto.ModifyParam{
						PrevState: proto.TaskStateRunning,
						Modifications: []proto.Modification{
							{Type: proto.ModifyConcurrency, To: 7},
						},
					}))
					// wait task move back to 'running' state
					require.Eventually(t, func() bool {
						gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, taskID)
						require.NoError(t, err2)
						return gotTask.State == proto.TaskStateRunning
					}, 10*time.Second, 100*time.Millisecond)
				}
			},
		)
		task, err := handle.SubmitTask(c.Ctx, "k2-2", proto.TaskTypeExample, "", 3, scope, 0, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})

	// Pause the task before it is scheduled, then modify its concurrency while
	// paused. The task must stay paused; after resume, all subtasks are
	// expected to use the new value.
	t.Run("modify paused task concurrency", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		syncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k3", proto.TaskTypeExample, "", 3, scope, 0, nil)
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				found, err := c.TaskMgr.PauseTask(c.Ctx, task.Key)
				require.NoError(t, err)
				require.True(t, found)
				theTask = task
				<-syncCh
			})
		})
		// Release the scheduler; the handshake also publishes theTask.
		syncCh <- struct{}{}
		taskBase := testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStatePaused, taskBase.State)
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, theTask.ID, &proto.ModifyParam{
			PrevState: proto.TaskStatePaused,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 7},
			},
		}))
		// After the modification is handled, the task returns to 'paused'.
		taskBase = testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStatePaused, taskBase.State)
		found, err := c.TaskMgr.ResumeTask(c.Ctx, theTask.Key)
		require.NoError(t, err)
		require.True(t, found)
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})

	// Simulate a race with another owner. After the task enters 'modifying',
	// the afterRefreshTask hook applies the modification itself through
	// ModifiedTask, moving the task back to 'pending' before this scheduler
	// continues. The task must still finish with the new concurrency.
	t.Run("modify pending task concurrency, but other owner already done it", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k4", proto.TaskTypeExample, "", 3, scope, 0, nil)
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				theTask = task
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
			})
		})
		var onceForRefresh sync.Once
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/afterRefreshTask",
			func(task *proto.Task) {
				onceForRefresh.Do(func() {
					require.Equal(t, proto.TaskStateModifying, task.State)
					// Act as the other owner: persist the modified task on a
					// copy, leaving the scheduler's own *task untouched.
					taskClone := *task
					taskClone.Concurrency = 7
					require.NoError(t, c.TaskMgr.ModifiedTask(c.Ctx, &taskClone))
					gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, task.ID)
					require.NoError(t, err)
					require.Equal(t, proto.TaskStatePending, gotTask.State)
					<-modifySyncCh
				})
			},
		)
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})

	// Modify max write speed (stored in the task meta) while pending. Only
	// the meta the step executor finally saw is checked.
	t.Run("modify pending task meta, only check the scheduler part", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k5", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.EqualValues(t, []byte("init"), task.Meta)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyMaxWriteSpeed, To: 123},
					},
				}))
				theTask = task
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
				<-modifySyncCh
			})
		})
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []byte("modify_max_write_speed=123"), *runtimeInfo.currTaskMeta.Load())
	})

	// Modify meta and increase concurrency while a subtask is running. The
	// running step executor must be notified in place, through TaskMetaModified
	// and ResourceModified.
	t.Run("modify meta and increase concurrency when subtask is running, and apply success", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		defer testModifyWhenSubtaskRun.Store(false)
		testModifyWhenSubtaskRun.Store(true)
		task, err := handle.SubmitTask(c.Ctx, "k6", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, []byte("init"), task.Meta)
		// First half of the handshake: returns once a subtask is running.
		modifyWaitCh <- struct{}{}
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
			PrevState: proto.TaskStateRunning,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 7},
				{Type: proto.ModifyMaxWriteSpeed, To: 123},
			},
		}))
		require.Eventually(t, func() bool {
			return "modify_max_write_speed=123" == string(*runtimeInfo.currTaskMeta.Load()) &&
				runtimeInfo.currTaskConcurrency.Load() == 7
		}, 10*time.Second, 100*time.Millisecond)
		// Stop later subtasks from waiting, then release the blocked one.
		testModifyWhenSubtaskRun.Store(false)
		modifyWaitCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.Equal(t, "modify_max_write_speed=123", string(*runtimeInfo.currTaskMeta.Load()))
		require.Equal(t, int64(7), runtimeInfo.currTaskConcurrency.Load())
	})

	// Same as above, but concurrency decreases (9 -> 5).
	t.Run("modify meta and decrease concurrency when subtask is running, and apply success", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		defer testModifyWhenSubtaskRun.Store(false)
		testModifyWhenSubtaskRun.Store(true)
		task, err := handle.SubmitTask(c.Ctx, "k7", proto.TaskTypeExample, "", 9, scope, 0, []byte("init"))
		require.NoError(t, err)
		require.Equal(t, 9, task.Concurrency)
		require.EqualValues(t, []byte("init"), task.Meta)
		modifyWaitCh <- struct{}{}
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
			PrevState: proto.TaskStateRunning,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 5},
				{Type: proto.ModifyMaxWriteSpeed, To: 456},
			},
		}))
		require.Eventually(t, func() bool {
			return "modify_max_write_speed=456" == string(*runtimeInfo.currTaskMeta.Load()) &&
				runtimeInfo.currTaskConcurrency.Load() == 5
		}, 10*time.Second, 100*time.Millisecond)
		testModifyWhenSubtaskRun.Store(false)
		modifyWaitCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.Equal(t, "modify_max_write_speed=456", string(*runtimeInfo.currTaskMeta.Load()))
		require.Equal(t, int64(5), runtimeInfo.currTaskConcurrency.Load())
	})

	// Modify MaxNodeCount of a running task. The stored value must be updated
	// and the task must still succeed.
	t.Run("modify running task max node count", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			// NOTE(review): as in the "at step two" subtest, `&&` returns only
			// when the task is neither running nor at StepOne, so any running
			// task passes this guard. Confirm whether `||` was intended.
			if task.State != proto.TaskStateRunning && task.Step != proto.StepOne {
				return
			}
			once.Do(func() {
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStateRunning,
					Modifications: []proto.Modification{
						{Type: proto.ModifyMaxNodeCount, To: 200},
					},
				}))
				<-modifySyncCh
			})
		})
		task, err := handle.SubmitTask(c.Ctx, "k8", proto.TaskTypeExample, "", 3, scope, 1, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, 1, task.MaxNodeCount)
		// finish StepOne
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		// wait task move to 'modifying' state
		modifySyncCh <- struct{}{}
		// wait task move back to 'running' state, and modified
		require.Eventually(t, func() bool {
			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, err2)
			return gotTask.State == proto.TaskStateRunning && gotTask.MaxNodeCount == 200
		}, 10*time.Second, 100*time.Millisecond)
		// finish StepTwo
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
	})
}

// TestModifyTaskMaxNodeCountForSubtaskBalance checks, on a 3-node cluster,
// that raising a running task's MaxNodeCount lets the balancer spread its
// subtasks over more nodes.
//
// With MaxNodeCount=1, exactly one subtask is active and all subtasks share
// a single exec_id. After MaxNodeCount is modified to 2, two subtasks become
// active on two distinct exec_ids.
func TestModifyTaskMaxNodeCountForSubtaskBalance(t *testing.T) {
	c, info, finishCh, _, _ := prepareModifyTaskTest(t, 3)
	scope := handle.GetTargetScope()
	t.Run("modify running task max node count, task can use more node after balance", func(t *testing.T) {
		defer func() {
			// Start the next subtest (if any) from a clean slate.
			*info = collectedRuntimeInfo{}
		}()

		var modifyOnce sync.Once
		gateCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			// By De Morgan, this is the same guard as returning early when
			// `State != Running && Step != StepOne`.
			if task.State == proto.TaskStateRunning || task.Step == proto.StepOne {
				modifyOnce.Do(func() {
					// Hold the scheduler until the test has checked the
					// single-node layout.
					<-gateCh
					require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
						PrevState: proto.TaskStateRunning,
						Modifications: []proto.Modification{
							{Type: proto.ModifyMaxNodeCount, To: 2},
						},
					}))
					// Hold again until the test acknowledges the modification.
					<-gateCh
				})
			}
		})

		task, err := handle.SubmitTask(c.Ctx, "k8", proto.TaskTypeExample, "", 3, scope, 1, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, 1, task.MaxNodeCount)

		// requireDistinctExecIDs asserts how many distinct exec_id values own
		// subtasks of the submitted task.
		requireDistinctExecIDs := func(want int) {
			require.NoError(t, c.TaskMgr.WithNewSession(func(se sessionctx.Context) error {
				rows, queryErr := sqlexec.ExecSQL(c.Ctx, se.GetSQLExecutor(), "select count(distinct exec_id) from mysql.tidb_background_subtask where task_key = %?", task.ID)
				require.NoError(t, queryErr)
				require.Equal(t, 1, len(rows))
				require.EqualValues(t, want, rows[0].GetInt64(0))
				return nil
			}))
		}

		// With MaxNodeCount=1, only a single subtask runs at a time.
		require.Eventually(t, func() bool {
			return info.activeSubtaskCount.Load() == 1
		}, 10*time.Second, 100*time.Millisecond)
		requireDistinctExecIDs(1)

		// Let the scheduler issue the modification, then let it proceed.
		for range 2 {
			gateCh <- struct{}{}
		}
		// The task should be running again with the new MaxNodeCount.
		require.Eventually(t, func() bool {
			gotTask, getErr := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, getErr)
			return gotTask.State == proto.TaskStateRunning && gotTask.MaxNodeCount == 2
		}, 10*time.Second, 100*time.Millisecond)

		// After rebalancing, two subtasks run concurrently on two nodes.
		require.Eventually(t, func() bool {
			return info.activeSubtaskCount.Load() == 2
		}, 10*time.Second, 100*time.Millisecond)
		requireDistinctExecIDs(2)

		// Finish the 2 StepOne subtasks, then the 3 StepTwo subtasks.
		for range 2 {
			finishCh <- struct{}{}
		}
		for range 3 {
			finishCh <- struct{}{}
		}
		finalTask := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, finalTask.State)
	})
}
